package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.CallbackFunction;

import java.util.Collection;
import java.util.List;
import java.util.ArrayList;

/**
 * Minimizes an objective by sweeping over a collection of items and, for each
 * item, computing a gradient against a snapshot of the shared weight vector and
 * folding that gradient back into the shared weights scaled by {@code -alpha}.
 * After every sweep the step size {@code alpha} is multiplied by
 * {@code upAlphaMult} when the summed value decreased relative to the previous
 * sweep, and by {@code downAlphaMult} otherwise.
 *
 * User: aria42
 * Date: Mar 10, 2009
 */
public class OldStochasticObjectiveOptimizer<I> {

  // --- Inputs of the current minimize() call ---
  Collection<I> items;
  List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
  Regularizer regularizer;

  // --- Step-size schedule (overwritten by the constructor) ---
  double initAlpha = 0.5;
  double upAlphaMult = 1.1;
  double downAlphaMult = 0.5;

  // --- Shared optimization state; weights is read/written under weightLock by mappers ---
  Object weightLock = new Object();
  double[] weights;
  double alpha ;

  // Optional hook invoked after each sweep; may be null.
  CallbackFunction iterDoneCallback;

  /**
   * @param initAlpha     step size used for the first sweep
   * @param upAlphaMult   factor applied to the step size after a sweep whose value decreased
   * @param downAlphaMult factor applied to the step size after a sweep whose value did not decrease
   */
  public OldStochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult)
  {
    this.downAlphaMult = downAlphaMult;
    this.upAlphaMult = upAlphaMult;
    this.initAlpha = initAlpha;
  }

  /**
   * Registers the callback that is invoked at the end of every sweep.
   *
   * @param iterDoneCallback callback to invoke, or {@code null} to disable
   */
  public void setIterationCallback(CallbackFunction iterDoneCallback) {
    this.iterDoneCallback = iterDoneCallback;
  }

  /**
   * Per-item worker bound to one item function. Accumulates into {@link #val}
   * the values returned by the item function and the regularizer for every
   * element it processes.
   */
  class Mapper implements SimpleMapper<I> {
    double val = 0.0;
    ObjectiveItemDifferentiableFunction<I> itemFn;

    Mapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
      this.itemFn = itemFn;
    }

    /**
     * Processes one element: snapshots the shared weights, evaluates the item
     * function and the regularizer into a fresh gradient buffer, then applies
     * that buffer to the shared weights.
     */
    public void map(I elem) {
      // Take a private copy so the gradient is computed against a consistent vector.
      final double[] snapshot;
      synchronized (weightLock) {
        snapshot = DoubleArrays.clone(weights);
      }

      final double[] gradient = new double[dimension()];
      itemFn.setWeights(snapshot);

      // Accumulate in the same order as the two contributions are produced.
      final double itemValue = itemFn.update(elem, gradient);
      val = val + itemValue;
      final double regValue = regularizer.update(snapshot, gradient, 1.0 / items.size());
      val = val + regValue;

      // Presumably weights += (-alpha) * gradient -- confirm against DoubleArrays.addInPlace.
      synchronized (weightLock) {
        DoubleArrays.addInPlace(weights, gradient, -alpha);
      }
    }
  }

  /**
   * Runs one sweep over {@code items} with one {@link Mapper} per item function.
   *
   * @return the sum of the values accumulated by all mappers during the sweep
   */
  private double doIter() {
    final List<Mapper> workers = new ArrayList<Mapper>();
    for (ObjectiveItemDifferentiableFunction<I> fn : itemFns) {
      final Mapper worker = new Mapper(fn);
      workers.add(worker);
    }

    // NOTE(review): assumed to hand items to the mappers and return once all are processed -- confirm.
    AsynchronousMapper.doMapping(items, workers);

    double total = 0.0;
    for (Mapper worker : workers) {
      total += worker.val;
    }
    return total;
  }

  /**
   * Runs {@code numIters} sweeps starting from a copy of {@code initWeights}.
   *
   * @param initWeights starting weights; cloned, not modified here
   * @param numIters    number of sweeps to perform
   * @param items       items visited in every sweep
   * @param itemFns     item functions, one mapper is created per entry each sweep
   * @param regularizer regularizer evaluated once per processed item
   * @return the optimizer's internal weight array after the last sweep
   */
  public double[] minimize(double[] initWeights,
                           int numIters,
                           Collection<I> items,
                           List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns,
                           Regularizer regularizer)
  {
    this.items = items;
    this.itemFns = itemFns;
    this.regularizer = regularizer;
    this.alpha = this.initAlpha;
    this.weights = DoubleArrays.clone(initWeights);

    // Infinity guarantees the first sweep counts as an improvement for any finite value.
    double previous = Double.POSITIVE_INFINITY;
    for (int iter = 0; iter < numIters; iter++) {
      final double current = doIter();

      if (current < previous) {
        alpha *= upAlphaMult;
      } else {
        alpha *= downAlphaMult;
      }
      previous = current;

      Logger.logs("[StochasticObjectiveOptimizer] Ended Iteration %d with value %.5f",iter+1,current);
      if (iterDoneCallback != null) {
        // The callback receives the zero-based sweep index and the already-rescaled alpha.
        iterDoneCallback.callback(iter,weights,current,alpha);
      }
    }
    return weights;
  }

  /**
   * @return the dimension reported by the first item function
   */
  public int dimension() {
    final ObjectiveItemDifferentiableFunction<I> first = itemFns.get(0);
    return first.dimension();
  }
}
